Type annotate pyro.primitives & poutine.block_messenger #3292

Conversation
```python
import torch
from torch.distributions import constraints
```
I see lots of new dependencies here. Historically we've had a number of cyclic dependency issues in Pyro. One thing we might consider to try to avoid cyclic dependencies is to guard these with an `if TYPE_CHECKING:` block:
```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from torch.distributions import constraints
    from pyro.distributions import TorchDistribution
    from pyro.params.param_store import ParamStoreDict
    from pyro.poutine.runtime import Message
```
Or maybe we can just use that trick next time we need to fix a cyclic dependency. Either way.
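
(For illustration, a minimal self-contained sketch of the pattern; the helper below is hypothetical, not code from this PR. The guarded import never runs at runtime, so it cannot create an import cycle, and the annotation that uses it is written as a string so only the type checker resolves it.)

```python
from typing import TYPE_CHECKING

import torch

if TYPE_CHECKING:
    # Only imported while type checking; never executed at runtime,
    # so it cannot participate in an import cycle.
    from pyro.distributions import TorchDistribution


def total_shape(fn: "TorchDistribution") -> torch.Size:
    # Hypothetical helper: the string annotation defers name resolution
    # to the type checker, so no runtime import is needed.
    return fn.batch_shape + fn.event_shape
```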
Great idea. However, in this case it turns out that `make docs` fails without `TorchDistribution` and `ParamStoreDict` imported. And `constraints` is used in actual code.
```diff
@@ -51,7 +51,7 @@ def __call__(self, sample_shape=torch.Size()):
         )

     @property
-    def event_dim(self):
+    def event_dim(self) -> int:
```
In this file I type-annotated only the methods needed for `pyro.primitives`.
```python
def effectful(
    fn: Optional[Callable[P, T]] = None, type: Optional[str] = None
) -> Callable:
```
This is the best I could do at being specific with callable types.
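
(To illustrate the constraint, a sketch under the assumption that `P` is a `ParamSpec` and `T` a `TypeVar`, the usual setup for such signatures; the body is illustrative, not Pyro's implementation. A wrapper produced by an effect handler may accept keyword arguments beyond `fn`'s own, which `ParamSpec` cannot describe, hence the loose `Callable` return type.)

```python
from typing import Callable, Optional, TypeVar

from typing_extensions import ParamSpec

P = ParamSpec("P")
T = TypeVar("T")


def effectful(
    fn: Optional[Callable[P, T]] = None, type: Optional[str] = None
) -> Callable:
    if fn is None:
        # Used as @effectful(type=...): return a decorator.
        return lambda f: effectful(f, type=type)

    def _fn(*args: P.args, **kwargs: P.kwargs) -> T:
        # A real effect handler would intercept the call here and may add
        # keywords of its own, which PEP 612 has no way to express.
        return fn(*args, **kwargs)

    return _fn
```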
```diff
-def sample(name, fn, *args, **kwargs):
+def sample(
+    name: str,
+    fn: TorchDistribution,
```
Is this correct, or can `fn` be any callable returning a `torch.Tensor`?
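
(For context, the typical call site passes a distribution instance, e.g.:)

```python
import pyro
import pyro.distributions as dist

# Draws a sample from a unit normal at the site named "x".
x = pyro.sample("x", dist.Normal(0.0, 1.0))
```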
```python
    *args,
    obs: Optional[torch.Tensor] = None,
    obs_mask: Optional[torch.Tensor] = None,
    infer: Optional[Dict[str, Union[str, bool]]] = None,
```
Made these args explicit.
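
(A brief usage sketch of the now-explicit keywords, conditioning a sample site on observed data:)

```python
import torch
import pyro
import pyro.distributions as dist

data = torch.tensor([0.5, 1.0])
with pyro.plate("data", len(data)):
    # obs fixes the site's value to the observations instead of sampling.
    y = pyro.sample("y", dist.Normal(0.0, 1.0), obs=data)
```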
```diff
@@ -374,7 +406,9 @@ def __init__(self, *args, **kwargs):


 @contextmanager
-def plate_stack(prefix, sizes, rightmost_dim=-1):
+def plate_stack(
+    prefix: str, sizes: Sequence[int], rightmost_dim: int = -1
```
The docstring says `sizes` is an iterable; however, an iterable is not necessarily reversible, so the annotation uses `Sequence[int]` instead.
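
(A small illustration of the distinction, not from the PR: `reversed()` needs the sequence protocol, which a bare `Iterable` does not guarantee.)

```python
from typing import Iterable, List, Sequence


def rev_iterable(sizes: Iterable[int]) -> List[int]:
    # mypy rejects this call: Iterable provides neither __reversed__ nor
    # __len__/__getitem__, and e.g. a generator also fails at runtime.
    return list(reversed(sizes))  # type: ignore[call-overload]


def rev_sequence(sizes: Sequence[int]) -> List[int]:
    return list(reversed(sizes))  # fine: a Sequence is reversible
```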
```diff
@@ -462,7 +498,7 @@ def module(name, nn_module, update_module_params=False):
                     param_name
                 ] = target_state_dict[_name]
             else:
-                nn_module._parameters[mod_name] = target_state_dict[_name]
+                nn_module._parameters[mod_name] = target_state_dict[_name]  # type: ignore[assignment]
```
`nn_module._parameters` holds values of type `nn.Parameter`, while `target_state_dict[_name]` is a `torch.Tensor`, hence the `type: ignore`.
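
(A minimal sketch of the mismatch, for illustration: `Module._parameters` maps names to `Optional[nn.Parameter]`, so assigning a plain tensor trips the checker.)

```python
import torch
import torch.nn as nn

module = nn.Linear(2, 2)
plain: torch.Tensor = torch.zeros(2, 2)

# Module._parameters is typed Dict[str, Optional[nn.Parameter]], so mypy
# reports an [assignment] error here: a Tensor is not a Parameter.
module._parameters["weight"] = plain  # type: ignore[assignment]
```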
Looks great!